# ResNet training and validation
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoImageProcessor
import os
from tqdm import tqdm
import numpy as np
# Config parameters
MODEL_CHOICE = 'resnet50'
MODEL_CHECKPOINTS = {
'resnet50': 'microsoft/resnet-50'
}
MODEL_CHECKPOINT = MODEL_CHECKPOINTS[MODEL_CHOICE]
DATA_DIR = 'newarchive'
DATA_USAGE_RATIO = 1  # fraction of the dataset to use (1 = all); defined but not applied below
NUM_EPOCHS = 40
BATCH_SIZE = 32
LEARNING_RATE = 2e-5
VAL_SPLIT = 0.2
TEST_SPLIT = 0.1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Data loading and transforms
class_names = sorted([d.name for d in os.scandir(DATA_DIR) if d.is_dir()])
num_labels = len(class_names)
label2id = {name: i for i, name in enumerate(class_names)}
id2label = {i: name for i, name in enumerate(class_names)}
print(f"Found {num_labels} classes: {class_names}")
image_processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
size = image_processor.size.get("shortest_edge", 224)  # fall back to 224 if the key is missing
mean = image_processor.image_mean
std = image_processor.image_std
# Training and validation transforms
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(size, scale=(0.7, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
val_transforms = transforms.Compose([
transforms.Resize(size),
transforms.CenterCrop(size),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
class TransformedDataset(Dataset):
    """Wraps a Subset so each split can carry its own transform."""
    def __init__(self, subset, transform=None):
self.subset = subset
self.transform = transform
def __getitem__(self, index):
x, y = self.subset[index]
if self.transform:
x = self.transform(x)
return x, y
def __len__(self):
return len(self.subset)
full_dataset = ImageFolder(root=DATA_DIR)
print(f"Full dataset size: {len(full_dataset)}")
test_size = int(len(full_dataset) * TEST_SPLIT)
train_val_size = len(full_dataset) - test_size
generator = torch.Generator().manual_seed(42)
dataset_to_split, test_subset = torch.utils.data.random_split(
full_dataset, [train_val_size, test_size], generator=generator
)
total_size = len(dataset_to_split)
val_size = int(total_size * VAL_SPLIT)
train_size = total_size - val_size
train_subset, val_subset = torch.utils.data.random_split(
    dataset_to_split, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)
train_dataset = TransformedDataset(train_subset, transform=train_transforms)
val_dataset = TransformedDataset(val_subset, transform=val_transforms)
test_dataset = TransformedDataset(test_subset, transform=val_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
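# Optional sanity check (a sketch): pull one batch and confirm tensor shapes
# before committing to a full training run.
sample_inputs, sample_labels = next(iter(train_loader))
print(f"Sample batch: inputs {tuple(sample_inputs.shape)}, labels {tuple(sample_labels.shape)}")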
print(f"\nOriginal dataset has been split into:")
print(f" - Training dataset size: {len(train_dataset)}")
print(f" - Validation dataset size: {len(val_dataset)}")
print(f" - Test dataset size: {len(test_dataset)}\n")
# Loading pre-trained model
model = AutoModelForImageClassification.from_pretrained(
MODEL_CHECKPOINT,
num_labels=num_labels,
label2id=label2id,
id2label=id2label,
ignore_mismatched_sizes=True
)
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
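# Optional: report how many parameters will be updated; nothing is frozen here,
# so this counts the full network.
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_trainable:,}")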
best_val_accuracy = 0.0
best_model_save_path = f'./{MODEL_CHOICE}_best_model'
# Training and validation epochs
for epoch in range(NUM_EPOCHS):
print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS} ---")
model.train()
train_loss = 0.0
train_corrects = 0
train_total = 0
pbar_train = tqdm(train_loader, desc="Training...")
for inputs, labels in pbar_train:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(pixel_values=inputs, labels=labels)
loss = outputs.loss
logits = outputs.logits
loss.backward()
optimizer.step()
train_loss += loss.item() * inputs.size(0)
_, preds = torch.max(logits, 1)
        train_corrects += torch.sum(preds == labels)
train_total += labels.size(0)
pbar_train.set_postfix({'loss': f'{loss.item():.4f}'})
epoch_train_loss = train_loss / train_total
epoch_train_acc = train_corrects.double() / train_total
    # Validation
model.eval()
val_loss = 0.0
val_corrects = 0
val_total = 0
with torch.no_grad():
        pbar_val = tqdm(val_loader, desc="Validation...")
for inputs, labels in pbar_val:
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(pixel_values=inputs, labels=labels)
loss = outputs.loss
logits = outputs.logits
val_loss += loss.item() * inputs.size(0)
_, preds = torch.max(logits, 1)
            val_corrects += torch.sum(preds == labels)
val_total += labels.size(0)
pbar_val.set_postfix({'loss': f'{loss.item():.4f}'})
epoch_val_loss = val_loss / val_total
epoch_val_acc = val_corrects.double() / val_total
print(f"Epoch {epoch + 1} Result: ")
print(f" Train - Loss: {epoch_train_loss:.4f}, Accuracy: {epoch_train_acc:.4f}")
print(f" Val - Loss: {epoch_val_loss:.4f}, Accuracy: {epoch_val_acc:.4f}")
    if epoch_val_acc > best_val_accuracy:
        best_val_accuracy = epoch_val_acc
        print(f"\nNew best model! Val accuracy: {best_val_accuracy:.4f}")
        os.makedirs(best_model_save_path, exist_ok=True)
        model.save_pretrained(best_model_save_path)
        image_processor.save_pretrained(best_model_save_path)
        print(f"Saved the best model to: {best_model_save_path}")
print("\n--- Train finished ---")
print(f"Best model saved to: {best_model_save_path}")
print(f"Best model accuracy: {best_val_accuracy:.4f}")
# Test accuracy + Confusion Matrix
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
from PIL import Image
import shutil
from transformers import AutoModelForImageClassification, AutoImageProcessor
import os
from tqdm import tqdm
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# CONFIG
MODEL_CHOICE = 'resnet50'
BEST_MODEL_PATH = f'./{MODEL_CHOICE}_best_model'
DATA_DIR = 'newarchive'
BATCH_SIZE = 32
TEST_SPLIT = 0.1
OUTPUT_DIR = 'sorted_predictions'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Dataset Class
class TransformedDataset(Dataset):
    """Wraps a Subset, applies a transform, and also returns each image's file path."""
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, index):
        x, y = self.subset[index]
        # Map the subset index back into the underlying ImageFolder to recover the file path
        original_idx = self.subset.indices[index]
        path = self.subset.dataset.samples[original_idx][0]
        if self.transform:
            x = self.transform(x)
        return x, y, path
def __len__(self):
return len(self.subset)
# Data loading and transforms (must match the training-time split and preprocessing)
print(f"Loading Image Processor from '{BEST_MODEL_PATH}'")
image_processor = AutoImageProcessor.from_pretrained(BEST_MODEL_PATH)
size = image_processor.size.get('shortest_edge', 224)
mean = image_processor.image_mean
std = image_processor.image_std
test_transforms = transforms.Compose([
transforms.Resize(size),
transforms.CenterCrop(size),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
print("\nLoading and dividing the dataset...")
full_dataset = ImageFolder(root=DATA_DIR)
test_size = int(len(full_dataset) * TEST_SPLIT)
train_val_size = len(full_dataset) - test_size
generator = torch.Generator().manual_seed(42)
_, test_subset = torch.utils.data.random_split(
full_dataset, [train_val_size, test_size], generator=generator
)
test_dataset = TransformedDataset(test_subset, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
print(f"Dataset Size for prediction: {len(test_dataset)}")
# Model Loading
print(f"\nLoading the best model from '{BEST_MODEL_PATH}'")
model = AutoModelForImageClassification.from_pretrained(BEST_MODEL_PATH)
model.to(device)
model.eval()
# Get class names from the model's configuration
class_names = [model.config.id2label[i] for i in range(len(model.config.id2label))]
print(f"\nFound {len(class_names)} classes: {class_names}")
if os.path.exists(OUTPUT_DIR):
print(f"Directory '{OUTPUT_DIR}' already exists. Removing it to start fresh.")
shutil.rmtree(OUTPUT_DIR)
os.makedirs(OUTPUT_DIR)
print(f"Creating {len(class_names)**2} subdirectories for sorting...")
# Create a nested folder structure
for true_style in class_names:
for pred_style in class_names:
os.makedirs(os.path.join(OUTPUT_DIR, f'actual_{true_style}', f'predicted_{pred_style}'), exist_ok=True)
print("Directories created successfully.")
# Prediction and sorting loop
all_preds = []
all_labels = []
print("\n--- Starting Prediction and Sorting ---")
with torch.no_grad():
pbar_test = tqdm(test_loader, desc="Predicting & Sorting")
for inputs, labels, paths in pbar_test:
inputs = inputs.to(device)
outputs = model(pixel_values=inputs).logits
_, preds = torch.max(outputs, 1)
labels_cpu = labels.cpu().numpy()
preds_cpu = preds.cpu().numpy()
for i in range(len(paths)):
original_path = paths[i]
true_label_idx = labels_cpu[i]
pred_label_idx = preds_cpu[i]
true_style_name = class_names[true_label_idx]
pred_style_name = class_names[pred_label_idx]
dest_folder = os.path.join(OUTPUT_DIR, f'actual_{true_style_name}', f'predicted_{pred_style_name}')
shutil.copy(original_path, dest_folder)
all_preds.extend(preds_cpu)
all_labels.extend(labels_cpu)
# Overall test accuracy and confusion matrix
accuracy = np.mean(np.array(all_preds) == np.array(all_labels))
print("\n--- Prediction & Sorting Finished ---")
print(f"Test Accuracy: {accuracy * 100:.2f}%")
print(f"Sorted images can be found in the '{OUTPUT_DIR}' directory.")
print("\n--- Creating confusion matrix ---")
cm = confusion_matrix(all_labels, all_preds, labels=range(len(class_names)))
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Label', fontsize=14)
plt.ylabel('True Label', fontsize=14)
plt.title('Confusion Matrix', fontsize=16)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
confusion_matrix_path = 'confusion_matrix.png'
plt.savefig(confusion_matrix_path)
print(f"Confusion matrix has been saved to: {confusion_matrix_path}")
plt.show()
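# Optional (a sketch): a row-normalized view (each row sums to 1) is often easier
# to read when class sizes differ; this reuses the cm computed above.
cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True).clip(min=1)
plt.figure(figsize=(12, 10))
sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Label', fontsize=14)
plt.ylabel('True Label', fontsize=14)
plt.title('Confusion Matrix (row-normalized)', fontsize=16)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('confusion_matrix_normalized.png')
plt.show()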
# Grad-CAM analysis
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoImageProcessor
import os
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import euclidean
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import random
import shutil
# Grad-CAM imports
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# Config
MODEL_CHOICE = 'resnet50'
BEST_MODEL_PATH = f'./{MODEL_CHOICE}_best_model'
SORTED_DATA_DIR = 'sorted_predictions'
ROMANTICISM_STYLE = "Romanticism"
REALISM_STYLE = "Realism"
RANDOM_STATE = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Loading model and preparing for feature extraction
full_model = AutoModelForImageClassification.from_pretrained(BEST_MODEL_PATH)
full_model.to(device)
full_model.eval()
# A copy of the model whose classification head is replaced by Identity, so its
# "logits" are the pooled backbone features (shape (B, 2048, 1, 1)) for t-SNE.
tsne_model = AutoModelForImageClassification.from_pretrained(BEST_MODEL_PATH)
tsne_model.classifier = torch.nn.Identity()
print("Final classification layer removed; tsne_model now returns pooled features for t-SNE.")
tsne_model.to(device)
tsne_model.eval()
image_processor = AutoImageProcessor.from_pretrained(BEST_MODEL_PATH)
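# A minimal t-SNE sketch using tsne_model (hedged: the helper names below are
# illustrative, and it assumes the two sorted_predictions folders exist with a
# reasonable number of images; perplexity is a rough default, not tuned).
def list_images(folder):
    return [os.path.join(folder, f) for f in os.listdir(folder)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

def extract_features(image_paths):
    feats = []
    with torch.no_grad():
        for p in image_paths:
            img = Image.open(p).convert('RGB')
            pixel_values = image_processor(img, return_tensors="pt").pixel_values.to(device)
            pooled = tsne_model(pixel_values=pixel_values).logits  # (1, 2048, 1, 1)
            feats.append(pooled.flatten(1).squeeze(0).cpu().numpy())
    return np.array(feats)

rom_dir = os.path.join(SORTED_DATA_DIR, f'actual_{ROMANTICISM_STYLE}', f'predicted_{ROMANTICISM_STYLE}')
rea_dir = os.path.join(SORTED_DATA_DIR, f'actual_{REALISM_STYLE}', f'predicted_{REALISM_STYLE}')
if os.path.isdir(rom_dir) and os.path.isdir(rea_dir):
    rom_feats = extract_features(list_images(rom_dir))
    rea_feats = extract_features(list_images(rea_dir))
    stacked = StandardScaler().fit_transform(np.vstack([rom_feats, rea_feats]))
    perplexity = min(30, max(2, len(stacked) // 4))
    emb = TSNE(n_components=2, random_state=RANDOM_STATE, perplexity=perplexity).fit_transform(stacked)
    plt.figure(figsize=(8, 6))
    plt.scatter(emb[:len(rom_feats), 0], emb[:len(rom_feats), 1], label=ROMANTICISM_STYLE, alpha=0.6)
    plt.scatter(emb[len(rom_feats):, 0], emb[len(rom_feats):, 1], label=REALISM_STYLE, alpha=0.6)
    plt.legend()
    plt.title('t-SNE of pooled ResNet features')
    plt.tight_layout()
    plt.show()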
# Grad-CAM tool
def get_target_layer(model, model_choice):
    # The last ResNet stage is a standard Grad-CAM target for HF's ResNetForImageClassification
    if model_choice == 'resnet50':
        return [model.resnet.encoder.stages[-1]]
    raise NotImplementedError(f"No Grad-CAM target layer defined for model choice: {model_choice}")
class HFModelWrapper(torch.nn.Module):
    """Adapts the HF model to the plain-tensor forward() pytorch_grad_cam expects."""
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        return self.model(pixel_values=x).logits
def generate_gradcam_visualization(model, image_processor, model_choice, image_path, true_label, pred_label, output_dir):
device = next(model.parameters()).device
os.makedirs(output_dir, exist_ok=True)
target_layers = get_target_layer(model, model_choice)
wrapped_model = HFModelWrapper(model)
cam = GradCAM(model=wrapped_model, target_layers=target_layers)
size = image_processor.size.get('shortest_edge', 224)
original_image = Image.open(image_path).convert('RGB')
rgb_img = np.array(original_image.resize((size, size))) / 255.0
input_tensor = image_processor(original_image, return_tensors="pt").pixel_values.to(device)
    pred_idx = model.config.label2id[pred_label]  # use the model passed in, not the global full_model
targets = [ClassifierOutputTarget(pred_idx)]
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(original_image.resize((size, size)))
axes[0].set_title(f"Original Image\nTrue Style: {true_label}")
axes[0].axis("off")
axes[1].imshow(visualization)
axes[1].set_title(f"Grad-CAM Heatmap\nModel Predicted: {pred_label}")
axes[1].axis("off")
plt.tight_layout()
base_filename = os.path.basename(image_path)
safe_filename = "".join([c for c in base_filename if c.isalpha() or c.isdigit() or c in ('.', '_')]).rstrip()
save_path = os.path.join(output_dir, f"TRUE_{true_label}_PRED_{pred_label}_{safe_filename}")
plt.savefig(save_path, dpi=150, bbox_inches="tight")
plt.close(fig)
print(f" - Visualization result has been saved to: {save_path}")
# Apply Grad-CAM to misclassified images
MISCLASSIFIED_STYLE = ROMANTICISM_STYLE
PREDICTED_STYLE = REALISM_STYLE
MISCLASSIFIED_IMAGE_DIR = os.path.join(SORTED_DATA_DIR, f'actual_{MISCLASSIFIED_STYLE}', f'predicted_{PREDICTED_STYLE}')
MAX_VISUALIZATIONS = 90
OUTPUT_VISUALS_DIR = f'gradcam_visualizations_actual_{MISCLASSIFIED_STYLE}_pred_{PREDICTED_STYLE}'
OUTPUT_ORIGINALS_DIR = f'original_images_actual_{MISCLASSIFIED_STYLE}_pred_{PREDICTED_STYLE}'
for d in [OUTPUT_VISUALS_DIR, OUTPUT_ORIGINALS_DIR]:
if os.path.exists(d):
shutil.rmtree(d)
os.makedirs(d, exist_ok=True)
if not os.path.isdir(MISCLASSIFIED_IMAGE_DIR):
    raise SystemExit(f"Error: the directory '{MISCLASSIFIED_IMAGE_DIR}' does not exist.")
all_misclassified_paths = [
    os.path.join(MISCLASSIFIED_IMAGE_DIR, f)
    for f in os.listdir(MISCLASSIFIED_IMAGE_DIR)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
]
count = 0
for img_path in all_misclassified_paths:
if count >= MAX_VISUALIZATIONS:
break
print(f"\n Found misclassified sample (No. {count + 1}): {img_path}")
generate_gradcam_visualization(
model=full_model,
image_processor=image_processor,
model_choice=MODEL_CHOICE,
image_path=img_path,
true_label=MISCLASSIFIED_STYLE,
pred_label=PREDICTED_STYLE,
output_dir=OUTPUT_VISUALS_DIR
)
orig_fname = os.path.basename(img_path)
dest_orig = os.path.join(OUTPUT_ORIGINALS_DIR, orig_fname)
with Image.open(img_path) as im:
im = im.convert('RGB')
im_resized = im.resize((224, 224), Image.BILINEAR)
im_resized.save(dest_orig)
count += 1
# Visualization
if count == 0:
print(f"Haven't found any pictures in '{MISCLASSIFIED_IMAGE_DIR}'")
else:
print(f"\nCreated {count} original+heatmap pairs.")
print(f"--- Presenting visualizations from '{OUTPUT_VISUALS_DIR}' ---")
composite_files = sorted([
os.path.join(OUTPUT_VISUALS_DIR, f)
for f in os.listdir(OUTPUT_VISUALS_DIR)
if f.endswith(('.png', '.jpg', '.jpeg'))
])
num = len(composite_files)
if num > 0:
cols = 2
rows = int(np.ceil(num / cols))
fig, axes = plt.subplots(rows, cols, figsize=(10, 5 * rows))
axes = axes.flatten() if rows * cols > 1 else [axes]
for idx, comp_path in enumerate(composite_files):
img = Image.open(comp_path)
axes[idx].imshow(img)
axes[idx].set_title(os.path.basename(comp_path), fontsize=10, wrap=True)
axes[idx].axis('off')
for j in range(idx + 1, len(axes)):
axes[j].axis('off')
plt.tight_layout(pad=2.0)
plt.show()
else:
print("No visualizations were generated.")
print(f"\nSaved all the original pictures to `{OUTPUT_ORIGINALS_DIR}`")